Rubin/Roman image access and processing tutorial
Imports
In [44]:
import os
import pickle
import numpy as np
import galsim
import importlib
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse, Circle
from matplotlib.collections import PatchCollection
%matplotlib inline
%config InlineBackend.figure_formats = ['svg']
import lsst.geom
import astropy.units as u
from astropy.wcs import WCS
import lsst.meas.algorithms as measAlg
import lsst.afw.math as afwMath
import lsst.afw.geom as afwGeom
import lsst.afw.detection as afwDet
import lsst.afw.table as afwTable
import lsst.afw.image as afwImage
from lsst.meas.algorithms.detection import SourceDetectionTask
from lsst.meas.deblender import SourceDeblendTask
from lsst.meas.base import SingleFrameMeasurementTask
import GCRCatalogs
from GCRCatalogs import GCRQuery
import desc_dc2_dm_data
In [3]:
from astropy.visualization import MinMaxInterval, PercentileInterval, ZScaleInterval, simple_norm
def plot(im_, interval='zscale', stretch='linear', title=None, fn=None, xlabel=None,
         ylabel=None, colorbar=True, cmap='jet', dpi=300, show=False, percentile=99.5,
         **kwargs):
    """Display a 2-D image with ``plt.imshow`` using an astropy interval and stretch.

    Parameters
    ----------
    im_ : galsim.Image or array-like
        Image to display; a ``galsim.Image`` is unwrapped to its ``.array``.
    interval : {'zscale', 'minmax', 'percentile'} or interval object
        How the display limits are chosen. Any object that has a
        ``get_limits(data)`` method (e.g. an astropy interval instance) is
        used as-is.
    stretch : str
        Stretch name passed to ``astropy.visualization.simple_norm``.
    title, xlabel, ylabel : str, optional
        Axis decorations; skipped when None.
    fn : str, optional
        If given, the current figure is saved to this path at ``dpi``.
    colorbar : bool
        Whether to draw a colour bar.
    cmap : str or colormap
        Colormap passed to ``plt.imshow``.
    dpi : int
        Resolution used only when saving to ``fn``.
    show : bool
        Call ``plt.show()`` at the end.
    percentile : float
        Percentage of pixels kept when ``interval='percentile'``.
    **kwargs
        Forwarded to ``plt.imshow`` (e.g. ``extent``).

    Returns
    -------
    The value returned by ``plt.imshow``.

    Raises
    ------
    ValueError
        If ``interval`` is neither a known name nor an interval object.
    """
    if isinstance(im_, galsim.Image):
        im_ = im_.array

    if interval == 'zscale':
        interval_ = ZScaleInterval()
    elif interval == 'minmax':
        interval_ = MinMaxInterval()
    elif interval == 'percentile':
        # PercentileInterval has a required ``percentile`` argument; calling it
        # with no arguments raises TypeError.
        interval_ = PercentileInterval(percentile)
    elif hasattr(interval, 'get_limits'):
        interval_ = interval
    else:
        raise ValueError(
            f"interval must be 'zscale', 'minmax', 'percentile' or an interval object, got {interval!r}"
        )

    vmin, vmax = interval_.get_limits(im_)
    norm = simple_norm(im_, stretch=stretch, min_cut=vmin, max_cut=vmax)
    f = plt.imshow(im_, norm=norm, origin='lower', cmap=cmap, **kwargs)

    # Only create the colour bar when it is wanted. Creating it and then
    # removing it still resizes the parent axes on some matplotlib versions.
    if colorbar:
        plt.colorbar()
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if title is not None:
        plt.title(title)
    if fn is not None:
        plt.savefig(fn, dpi=dpi)
    if show:
        plt.show()
    return f
Parameter setup
In [4]:
# Rubin (LSST) and Roman band names; images are later keyed by the first letter.
lsst_bands = ['r', 'i', 'z', 'y', 'g', 'u']
roman_bands = ['Y106', 'J129', 'H158', 'F184']
roman_bands_i = [band[0] for band in roman_bands]

# os.path.expanduser('~') uses $HOME like os.getenv("HOME") did. When HOME is
# unset it falls back to other lookups (e.g. the password database on POSIX)
# instead of returning None and breaking the path construction below.
home = os.path.expanduser('~')
roman_fn_dict = {
    band: os.path.join(home, 'roman_coadd', f'dc2_{band}_54.24_-38.3.fits')
    for band in roman_bands
}
# NOTE(review): the PSF files are tagged dec -42.1 while the images are at
# dec -38.3 -- confirm that this PSF is meant to be used for these images.
roman_fn_psf_dict = {
    band: os.path.join(home, 'roman_coadd_psf', f'dc2_{band}_54.24_-42.1_psf.fits')
    for band in roman_bands
}

butler = desc_dc2_dm_data.get_butler('2.2i_dr6_wfd')

ra, dec = 54.28, -38.30  # cutout centre, in degrees
rubin_scale = 0.2        # presumably arcsec / Rubin pixel -- confirm
roman_scale = 0.0575     # presumably arcsec / Roman pixel -- confirm
stampsize = 150          # Rubin cutout side length, in pixels
scale_ratio = rubin_scale / roman_scale  # Roman pixels per Rubin pixel (if units above hold)
Rubin image
In [5]:
def get_rubin_img(ra, dec, butler, butlerDataset, size, band):
    """Return a ``size`` x ``size`` pixel cutout of a Rubin coadd centred on (ra, dec).

    The tract and patch containing the position are looked up in the
    ``deepCoadd_skyMap``. The whole patch of ``butlerDataset`` is then read for
    ``band`` and trimmed to the requested cutout.

    Parameters
    ----------
    ra, dec : float
        Centre of the cutout, in degrees.
    butler : Butler
        Data butler used for the skymap and the image.
    butlerDataset : str
        Dataset type to read, e.g. ``'deepCoadd'``.
    size : int
        Side length of the cutout, in pixels.
    band : str
        Filter name used in the data ID.
    """
    center = lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees)

    tract_info = butler.get('deepCoadd_skyMap').findTract(center)
    patch_index = tract_info.findPatch(center).getIndex()

    data_id = {
        "tract": tract_info.getId(),
        "patch": "{},{}".format(*patch_index),
        "filter": band,
    }
    coadd_patch = butler.get(butlerDataset, dataId=data_id)
    return coadd_patch.getCutout(center, lsst.geom.ExtentI(size, size))
In [6]:
# One Rubin cutout per LSST band, keyed by the band's first letter.
imageDict_rubin = {}
for band in lsst_bands:
    imageDict_rubin[band[0]] = get_rubin_img(
        ra, dec, butler, 'deepCoadd', size=stampsize, band=band,
    )
In [7]:
# Display each Rubin cutout in its own figure, in parent-pixel coordinates.
for band, img in imageDict_rubin.items():
    bbox = img.getBBox()
    # LSST puts pixel centres at integer coordinates, but imshow's extent
    # gives the pixel *edges*. Shift by half a pixel so the tick labels refer
    # to pixel centres, matching the centroid/peak overlays drawn later.
    extent = (bbox.getBeginX() - 0.5, bbox.getEndX() - 0.5,
              bbox.getBeginY() - 0.5, bbox.getEndY() - 0.5)
    fig, ax = plt.subplots(figsize=(8, 8), dpi=300)
    plot(img.getMaskedImage().getImage().array, cmap='gray', colorbar=False,
         extent=extent, title=band)
Roman image
In [8]:
def get_roman_image(ra, dec, fn_img, size, fn_psf):
    """Read a Roman image from FITS and return a square cutout centred on (ra, dec).

    Parameters
    ----------
    ra, dec : float
        Centre of the cutout, in degrees.
    fn_img : str
        Path of the FITS file read with ``ExposureF.readFits``.
    size : int or float
        Side length of the cutout, in pixels. A non-integer value (for example
        ``stampsize * scale_ratio``) is rounded to the nearest integer, because
        ``lsst.geom.ExtentI`` needs integers.
    fn_psf : str or None
        Path of a PSF image. If truthy, it is wrapped in a fixed-kernel PSF and
        attached to the exposure before the cutout is made.
    """
    radec = lsst.geom.SpherePoint(ra, dec, lsst.geom.degrees)
    full_patch = afwImage.ExposureF.readFits(fn_img)
    if fn_psf:
        psf_image = afwImage.ImageD(fn_psf)
        full_patch.setPsf(measAlg.KernelPsf(afwMath.FixedKernel(psf_image)))
    npix = int(round(size))
    cutout_extent = lsst.geom.ExtentI(npix, npix)
    return full_patch.getCutout(radec, cutout_extent)
In [9]:
# Roman pixels are smaller, so take scale_ratio times as many of them to cover
# the same sky area as the Rubin stamp. ExtentI needs an integer size, but
# stampsize * scale_ratio is a float.
roman_stampsize = int(round(stampsize * scale_ratio))
# One Roman cutout per band, keyed by the band's first letter ('Y', 'J', 'H', 'F').
imageDict_roman = {
    band[0]: get_roman_image(ra, dec, fn_img=roman_fn_dict[band],
                             fn_psf=roman_fn_psf_dict[band], size=roman_stampsize)
    for band in roman_bands
}
In [10]:
# Display each Roman cutout in its own figure, in parent-pixel coordinates.
for band, img in imageDict_roman.items():
    bbox = img.getBBox()
    # LSST puts pixel centres at integer coordinates, but imshow's extent
    # gives the pixel *edges*. Shift by half a pixel so the tick labels refer
    # to pixel centres, matching the centroid/peak overlays drawn later.
    extent = (bbox.getBeginX() - 0.5, bbox.getEndX() - 0.5,
              bbox.getBeginY() - 0.5, bbox.getEndY() - 0.5)
    fig, ax = plt.subplots(figsize=(8, 8), dpi=300)
    plot(img.getMaskedImage().getImage().array, cmap='gray', colorbar=False,
         extent=extent, title=band)
Image processing
In [11]:
## Pipeline: source detection -> single-band deblending -> measurement.

# Detection uses its default configuration.
config_detection = SourceDetectionTask.ConfigClass()

# Deblender: propagateAllPeaks on, and an empty maskPlanes list.
config_deblend = SourceDeblendTask.ConfigClass()
config_deblend.propagateAllPeaks = True
config_deblend.maskPlanes = []

# Measurement uses its default configuration.
config_meas = SingleFrameMeasurementTask.ConfigClass()

# All three tasks are given the same schema object. Presumably each one adds
# its own output fields to it, so they are built in pipeline order -- confirm.
schema = afwTable.SourceTable.makeMinimalSchema()
detectionTask = SourceDetectionTask(config=config_detection, schema=schema)
sourceDeblendTask = SourceDeblendTask(config=config_deblend, schema=schema)
measureTask = SingleFrameMeasurementTask(config=config_meas, schema=schema)
# NOTE(review): this def lost its name in the export (`def (self, image):` is a
# SyntaxError). It reads self.schema, self.detectionTask,
# self.sourceDeblendTask, self.measureTask, self.detection_dosmooth and
# self.detection_sigma, so it presumably belonged to a class that is not shown
# here. TODO: restore the method name and its enclosing class.
def (self, image):
    """Run detection, deblending and measurement on ``image.exp``.

    Returns:
        (sources, detections): ``sources`` is the catalog taken from the
        detection result. The deblend and measure tasks are then run on it
        (their return values are not used). ``detections`` is the object
        returned by ``self.detectionTask.run``.
    """
    exp = image.exp
    # Source table built from the schema shared by the pipeline tasks.
    tab = afwTable.SourceTable.make(self.schema)
    ## Note that exp will be modified after running detection (calexp)
    detections = self.detectionTask.run(tab, exp, doSmooth=self.detection_dosmooth, sigma=self.detection_sigma)
    sources = detections.sources
    self.sourceDeblendTask.run(exp, sources) ##exp is now calexp
    self.measureTask.measure(sources, exp)
    return sources, detections
In [46]:
## Process the Rubin 'r' band and the Roman 'H' band with the same pipeline.


def _run_pipeline(exp):
    """Run detection, deblending and measurement on ``exp``.

    Returns ``(tab, detections, sources)``. ``exp`` is modified in place by
    detection, so afterwards it is the calexp.
    """
    tab = afwTable.SourceTable.make(schema)
    # sigma=None with doSmooth=True: no explicit smoothing sigma is passed.
    detections = detectionTask.run(tab, exp, doSmooth=True, sigma=None)
    sources = detections.sources
    sourceDeblendTask.run(exp, sources)
    measureTask.measure(sources, exp)
    return tab, detections, sources


## r band
exp_r = imageDict_rubin['r']
tab_r, detections_r, sources_r = _run_pipeline(exp_r)
## H band
exp_H = imageDict_roman['H']
tab_H, detections_H, sources_H = _run_pipeline(exp_H)
## output: detections_r/H, sources_r/H
Plot sources and peaks
In [43]:
bbox = exp_r.getBBox()
# LSST puts pixel centres at integer coordinates, but imshow's extent gives
# pixel *edges*. Without the half-pixel shift, every peak and centroid below
# would be drawn 0.5 px away from the pixel it belongs to.
extent = (bbox.getBeginX() - 0.5, bbox.getEndX() - 0.5,
          bbox.getBeginY() - 0.5, bbox.getEndY() - 0.5)
fig, ax = plt.subplots(figsize=(8, 8), dpi=300)

# draw image
plot(exp_r.getMaskedImage().getImage().array, cmap='gray', colorbar=False,
     extent=extent, title='r')

# draw every peak of every footprint (parents included)
px = []
py = []
for sr in sources_r:
    for pp in sr.getFootprint().getPeaks():
        px.append(pp.getFx())
        py.append(pp.getFy())
plt.scatter(px, py, c='#142c8c', marker='+', linewidths=0.8)

# draw measured shapes, skipping deblended parents (nChild > 0) and
# records with any pixel flag set
flag = (sources_r['deblend_nChild'] > 0) | sources_r['base_PixelFlags_flag']
sources = sources_r[~flag]
x = sources['base_SdssCentroid_x']
y = sources['base_SdssCentroid_y']
axes = [afwGeom.ellipses.Axes(s.getShape()) for s in sources]
# Ellipse width/height are full lengths, set to size_scale x the semi-axes A/B.
size_scale = 3.0
ellipses = [
    Ellipse((xc, yc),
            width=ax_.getA() * size_scale,
            height=ax_.getB() * size_scale,
            angle=np.rad2deg(ax_.getTheta()))
    for xc, yc, ax_ in zip(x, y, axes)
]
collection = PatchCollection(ellipses, edgecolor='r', facecolor='None')
ax.add_collection(collection)
plt.tight_layout()
plt.show()
In [49]:
bbox = exp_H.getBBox()
# LSST puts pixel centres at integer coordinates, but imshow's extent gives
# pixel *edges*. Without the half-pixel shift, every peak and centroid below
# would be drawn 0.5 px away from the pixel it belongs to.
extent = (bbox.getBeginX() - 0.5, bbox.getEndX() - 0.5,
          bbox.getBeginY() - 0.5, bbox.getEndY() - 0.5)
fig, ax = plt.subplots(figsize=(8, 8), dpi=300)

# draw image
plot(exp_H.getMaskedImage().getImage().array, cmap='gray', colorbar=False,
     extent=extent, title='H')

# draw every peak of every footprint (parents included)
px = []
py = []
for sr in sources_H:
    for pp in sr.getFootprint().getPeaks():
        px.append(pp.getFx())
        py.append(pp.getFy())
plt.scatter(px, py, c='#142c8c', marker='+', linewidths=0.8)

# draw measured shapes, skipping deblended parents (nChild > 0) and
# records with any pixel flag set
flag = (sources_H['deblend_nChild'] > 0) | sources_H['base_PixelFlags_flag']
sources = sources_H[~flag]
x = sources['base_SdssCentroid_x']
y = sources['base_SdssCentroid_y']
axes = [afwGeom.ellipses.Axes(s.getShape()) for s in sources]
# The centroids, the shape moments and the plot extent are all in Roman
# pixels, so use the same factor as the r-band plot. Multiplying by
# scale_ratio (~3.5) would apply the pixel-scale conversion a second time and
# inflate every ellipse. Width/height are full lengths, size_scale x the
# semi-axes A/B.
size_scale = 3.0
ellipses = [
    Ellipse((xc, yc),
            width=ax_.getA() * size_scale,
            height=ax_.getB() * size_scale,
            angle=np.rad2deg(ax_.getTheta()))
    for xc, yc, ax_ in zip(x, y, axes)
]
collection = PatchCollection(ellipses, edgecolor='r', facecolor='None')
ax.add_collection(collection)
plt.tight_layout()
plt.show()